#####################################################
#                                                   #
#  All the functions needed to analyse and load     #
#  the data from the quantum jumps experiment.      #
#                                                   #
#####################################################

import sys, os
import time as time

import h5py
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans

import scipy.special as spec
from scipy.optimize import curve_fit
from scipy.integrate import cumtrapz

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

# Plotting configuration, applied as soon as this module is imported:
# seaborn "ticks" style plus module-wide matplotlib defaults (thick axes,
# larger default font, square 5x5 figures).
sns.set_style("ticks")
plt.rcParams['axes.linewidth'] = 3
plt.rcParams["font.size"] = 15  # default font size used by every plot
plt.rcParams['figure.figsize'] = [5,5]
# `colors` is the default matplotlib colour cycle as a list of colour strings;
# plotting helpers in this module (e.g. basic2plot, kmeans_plot) index into it.
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']
# Import-time side effect: echoes the colour cycle to stdout.
print("Default colour cycle: " +str([color for color in colors]))


##########################################
######### Data loading and saving ########
##########################################

def find_path_from_timestamp(timestamp:str, relative_path:str='')->list:
    """Return every path listed in ``hdf5_paths.txt`` that contains *timestamp*.

    The index file is opened as ``f'{relative_path}hdf5_paths.txt'``, so
    *relative_path* must be empty (current directory) or end with a path
    separator. Each line of the index file is one path.

    Returns
    -------
    list of str
        Matching paths in file order; empty when nothing matches.
    """
    with open(f'{relative_path}hdf5_paths.txt', 'r') as index_file:
        file_paths = index_file.read().splitlines()
    return [path for path in file_paths if timestamp in path]

def getfiles(dirpath, extension, withstring = None):
    """
    Return the names of the regular files in *dirpath* ending with *extension*.

    Names are sorted by modification time, oldest first. If *withstring* is
    given, only names that contain it are kept. Paths are built with
    ``os.path.join``, so a trailing separator on *dirpath* is optional here
    (callers that later do ``dirpath + name`` still need one).
    """
    names = [s for s in os.listdir(dirpath)
             if os.path.isfile(os.path.join(dirpath, s))]
    names.sort(key=lambda s: os.path.getmtime(os.path.join(dirpath, s)))
    return [name for name in names
            if name.endswith(extension)
            and (withstring is None or withstring in name)]


def getfiles_hdf5(dirpath, withstring = None):
    """
    Return the names of the ``.hdf5`` files in *dirpath*, oldest first.

    Only regular files are considered (directories are skipped). If
    *withstring* is given, only names containing it are returned.
    """
    names = [s for s in os.listdir(dirpath)
             if os.path.isfile(os.path.join(dirpath, s))]
    names.sort(key=lambda s: os.path.getmtime(os.path.join(dirpath, s)))
    return [name for name in names
            if name.endswith('.hdf5')
            and (withstring is None or withstring in name)]


## load into a dictionary together with all the keys
def load_h5_to_dic(fullpath):
    """Load an HDF5 file into (nested) dictionaries.

    The layout is decided by the type of the FIRST root key:

    * root holds datasets -> ``({name: data}, [names])``
    * root holds groups   -> ``({group: {name: data}}, {group: [names]})``

    Only one level of groups is read, and every member of a group is read
    with ``[()]``, so group members are expected to be datasets. An empty
    file returns ``({}, [])``.
    """
    with h5py.File(fullpath, 'r') as file:
        main_keys = list(file["/"].keys())
        data_vector = {}
        if not main_keys:
            # Nothing stored; indexing main_keys[0] below would raise IndexError.
            return data_vector, main_keys
        if isinstance(file[main_keys[0]], h5py.Dataset):
            for key in main_keys:
                data_vector[key] = file[key][()]
            return data_vector, main_keys
        datasets_keys_list = {}
        for key in main_keys:
            group = file[key]
            datasets_keys = list(group.keys())
            datasets_keys_list[key] = datasets_keys
            data_vector[key] = {d_key: group[d_key][()] for d_key in datasets_keys}
        return data_vector, datasets_keys_list


def save_h5(fullpath,datasets,group=None, overwrite=False, compression="gzip"):
    """Write every item of *datasets* into the HDF5 file *fullpath*.

    Parameters
    ----------
    fullpath : str
        Target file.
    datasets : mapping
        name -> array-like; names are converted with ``str``.
    group : optional
        If truthy, a new group ``str(group)`` is created and the datasets are
        written inside it; otherwise they are written at the file root.
    overwrite : bool
        True opens the file in mode 'w' (existing content is discarded);
        False opens it in mode 'a' (append).
    compression : str
        Forwarded to ``create_dataset``.

    Side effect: sets the global ``h5py.get_config().track_order = True``,
    which persists for the rest of the process.
    """
    h5py.get_config().track_order = True
    mode = 'w' if overwrite else 'a'
    with h5py.File(fullpath, mode, track_order=True) as h5_file:
        # Datasets go either into a freshly created group or straight into the root.
        target = h5_file.create_group(str(group), track_order=True) if group else h5_file
        for name in datasets:
            target.create_dataset(str(name), data=datasets[name], compression=compression)

##########################################
######### Utility functions ##############
##########################################

def PrintStatic(s):
    """Overwrite the current terminal line with *s*.

    The text is padded with spaces to 78 columns so leftovers of a longer
    previous message are erased, and ends with a carriage return so the next
    call writes over it. The flush comes after the write so the message is
    shown immediately instead of waiting in the stdout buffer.
    """
    sys.stdout.write(s + " " * (78 - len(s)) + "\r")
    sys.stdout.flush()
    
def better_sleep(t_sleep):
    """Better way of pausing a cell and annoying manu.

    Sleeps *t_sleep* seconds in 1 s steps, updating a progress line after
    each step, then sleeps the fractional remainder. Non-positive durations
    return immediately.
    """
    if t_sleep <= 0:
        return
    PrintStatic(("Annoying Manu for %i s of %s s")%(0, t_sleep))
    for counter in range(int(t_sleep)):
        time.sleep(1)
        PrintStatic(("Annoying Manu for %i s of %s s")%(counter+1, t_sleep))
    # Fractional remainder. `t_sleep % int(t_sleep)` divided by zero for 0 < t_sleep < 1.
    time.sleep(t_sleep % 1)
    print(("Annoyed Manu for %g s                     \n")%(t_sleep))
    
def calculate_field_xyz(field_value_mT,theta,phi,psi):
    """Return the (Bx, By, Bz) components of a field of magnitude *field_value_mT*.

    Components are expressed in the magnet reference frame for the angles
    *theta*, *phi* and *psi* (radians, fed to numpy trig functions). Works
    element-wise on numpy arrays.
    """
    sin_t, cos_t = np.sin(theta), np.cos(theta)
    sin_p, cos_p = np.sin(phi), np.cos(phi)
    sin_s, cos_s = np.sin(psi), np.cos(psi)
    bx = field_value_mT * (sin_t * sin_p - cos_t * sin_s * cos_p)
    by = field_value_mT * (sin_t * cos_p + cos_t * sin_s * sin_p)
    bz = field_value_mT * cos_t * cos_s
    return bx, by, bz
    
def field_angle_to_cartesian(field, angle_degrees):
    """Split *field* into [x, y, z] components for an angle in the y-z plane.

    The angle (in degrees) is measured from the z axis towards y; the x
    component is always 0.
    """
    angle_rad = angle_degrees * np.pi / 180
    return [0, np.sin(angle_rad) * field, np.cos(angle_rad) * field]
    
def downsample(data,window):
    """Average *data* over consecutive, non-overlapping chunks of *window* samples.

    Trailing samples that do not fill a complete window are discarded.
    *data* must support ``reshape`` (a numpy array, presumably 1-D).
    """
    usable = len(data) - int(len(data) % window)
    chunks = data[:usable].reshape(-1, window)
    return np.mean(chunks, axis=1)

def estimate_fwhm(x_data, y_data, invert = False):
    """Estimate the full width at half maximum of the main peak in *y_data*.

    The baseline is the mean of the first and last 10 % of the samples. The
    width runs from the last sample left of the peak to the first sample
    right of it whose baseline-corrected value is at or below half the peak
    height (no interpolation, so the result is quantised to the x grid).
    If the curve never drops to half maximum on one side, the first/last
    sample is used on that side, which gives a lower bound instead of an
    IndexError.

    Parameters
    ----------
    x_data, y_data : array-like of equal length
    invert : bool
        True to measure a dip instead of a peak.

    Returns
    -------
    ``x_data[right] - x_data[left]``
    """
    x_data = np.asarray(x_data)
    y_data = np.asarray(y_data)
    if invert:
        y_data = -y_data

    n = len(y_data)
    peak_index = np.argmax(y_data)
    edges = np.concatenate([y_data[:int(0.1 * n)], y_data[int(0.9 * n):]])
    corrected_y = y_data - np.mean(edges)
    half_max = corrected_y[peak_index] / 2

    left_candidates = np.where(corrected_y[:peak_index] <= half_max)[0]
    right_candidates = np.where(corrected_y[peak_index:] <= half_max)[0]
    left_idx = left_candidates[-1] if left_candidates.size else 0
    right_idx = right_candidates[0] + peak_index if right_candidates.size else n - 1

    return x_data[right_idx] - x_data[left_idx]

def calc_chirp_rate(start_freq,stop_freq,pulse_duration):
    """
    Convert a frequency sweep into a chirp rate for a pulse of given length.

    *pulse_duration* is in ns. The result is truncated to an int and equals
    ``(stop_freq - start_freq) * 1e-3 / (pulse_duration * 1e-9)``.

    NOTE(review): because of the 1e-3 factor, frequencies given in Hz yield
    a rate in kHz/s rather than Hz/s. Confirm the intended units with the
    callers.
    """
    frequency_span = stop_freq - start_freq
    duration_s = pulse_duration * 1e-9
    return int(frequency_span * 1e-3 / duration_s)


def sinhspace(start,stop,npoints, nonlinearity = 2*np.e):
    """
    Return *npoints* values from *start* to *stop*, spaced like a sinh curve.

    Same call signature as ``numpy.linspace``. Points are symmetric about the
    middle of the range. ``nonlinearity=0`` gives a linear sweep; larger
    values put more points near the centre. Handy for sweeping quickly
    across a resonance at a known location.
    """
    if start == stop:
        return np.ones(npoints) * start
    if nonlinearity == 0:
        return np.linspace(start, stop, npoints)
    midpoint = (start + stop) / 2
    span = stop - start
    # Multiplying by span/|span| makes the sweep run from start to stop even when stop < start.
    profile = np.sinh(nonlinearity * span / abs(span) * np.linspace(-1, 1, npoints))
    return midpoint + 0.5 * abs(span) * profile / max(profile)


def sinhspace_asymm(start,stop,npoints, nonlinearity = None):
    """
    Return *npoints* values from *start* to *stop* crowded towards the smaller end.

    Same call signature as ``numpy.linspace``, but points are asymmetrically
    spaced along a sinh curve with more points at values closer to zero.
    ``nonlinearity=0`` gives a linear sweep; larger values crowd the points
    more. The default ``log(|stop/start|)`` behaves much like a
    logarithmic sweep. When either endpoint is 0 the default is ``log(100)``,
    similar to a logarithmic sweep over 2 orders of magnitude. Useful for T1,
    T2 and similar measurements.
    """
    if nonlinearity is None:
        if start != 0 and stop != 0:
            nonlinearity = np.log(abs(stop / start))
        else:
            # A zero endpoint would give log(0) = -inf and a NaN/inf sweep.
            nonlinearity = np.log(100)
    if start == stop:
        return np.ones(npoints) * start
    if nonlinearity == 0:
        return np.linspace(start, stop, npoints)

    sweeprange = stop - start
    # Run the sinh argument so that its steep part lands on the endpoint
    # with the larger magnitude; the fine steps then sit near zero.
    if abs(stop) > abs(start):
        ramp = np.linspace(0, 1, npoints)
    else:
        ramp = np.linspace(1, 0, npoints)
    sweep = np.sinh(nonlinearity*sweeprange/abs(sweeprange)*ramp)
    sweep = sweep - sweep[0]
    sweep = sweep / sweep[-1]
    return start + sweeprange * sweep
 
 
def kmeans_4d(data):
    '''
    Cluster 4-component readout count vectors into 4 states with k-means.

    The dimensions of *data* should be [N, 4, 4]: N averages, 4 preparation
    states and 4 readout counts. The last axis is the feature vector; all
    leading axes are flattened for fitting.

    The fit is deterministic. Centroid j is seeded at 100 counts on readout
    j and 50 on the others (presumably so that cluster j lines up with
    readout j), and a single initialisation is run.

    Returns
    -------
    probs : ndarray of shape (4,) + data.shape[1:-1]
        probs[j] is the fraction of points assigned to cluster j, averaged
        over axis 0.
    kmeans : the fitted sklearn ``KMeans`` model (4 features).
    '''
    seeds = np.full((4, 4), 50)
    np.fill_diagonal(seeds, 100)
    lead_shape = data.shape[:-1]
    samples = data.reshape(np.cumprod(lead_shape)[-1], 4)
    kmeans = KMeans(n_clusters=4, random_state=0, init=seeds, n_init=1)
    labels = kmeans.fit_predict(samples).reshape(lead_shape)
    probs = np.array([np.mean(labels == cluster, axis=0) for cluster in range(4)])
    return probs, kmeans



def extract_populations_4state(click_NRO, frequency_domain = False, accumulated = False, kmeans = None):
    """Estimate the population of each of 4 states from readout click counts.

    Parameters
    ----------
    click_NRO : ndarray
        Readout counts whose last axis holds the 4 readout channels. Unless
        *accumulated* is True, the second-to-last axis is summed first.
        Populations are averaged over axis 0.
    frequency_domain : bool
        False: time-domain readout, discriminated in phase with
        ``dx = c0 - c2`` and ``dy = c1 - c3``.
        True: frequency-selective readout, with
        ``dx = c0 + c1 - c2 - c3`` and ``dy = c0 + c2 - c1 - c3``.
    accumulated : bool
        True if *click_NRO* is already summed over the readout repetitions.
    kmeans : fitted classifier or None
        If given, its ``predict`` replaces the sign-quadrant classification.
        ``n_features_in_ == 2`` classifies (dx, dy) pairs; ``== 4`` classifies
        the 4 (summed) readout counts.

    Returns
    -------
    q0, q1, q2, q3, delta_x, delta_y
        Without kmeans: q0 = (dx>0, dy>0), q1 = (dx>0, dy<0),
        q2 = (dx<0, dy>0), q3 = (dx<0, dy<0); points exactly on an axis
        count towards none. With kmeans: fraction in cluster 0..3. Both are
        averaged over axis 0. delta_x/delta_y are the differential quantities.

    Raises
    ------
    ValueError
        If *kmeans* was fitted on a number of features other than 2 or 4.
    """
    # Sum over the readout-repetition axis unless that was already done.
    XY_NRO = click_NRO if accumulated else click_NRO.sum(-2)
    c0, c1, c2, c3 = (np.take(XY_NRO, k, -1) for k in range(4))

    # Two differential quantities define the axes used to discriminate states.
    if frequency_domain:
        delta_x = c0 + c1 - c2 - c3
        delta_y = c0 + c2 - c1 - c3
    else:
        delta_x = c0 - c2
        delta_y = c1 - c3

    if kmeans is None:
        q0 = np.logical_and(delta_x > 0, delta_y > 0).mean(0)
        q1 = np.logical_and(delta_x > 0, delta_y < 0).mean(0)
        q2 = np.logical_and(delta_x < 0, delta_y > 0).mean(0)
        q3 = np.logical_and(delta_x < 0, delta_y < 0).mean(0)
        return q0, q1, q2, q3, delta_x, delta_y

    n_features = kmeans.n_features_in_
    if n_features == 2:
        lead_shape = delta_x.shape
        features = np.transpose([delta_x.flatten(), delta_y.flatten()])
    elif n_features == 4:
        # Classify the same summed counts the deltas are built from; raw
        # click_NRO still carries the unsummed axis when accumulated=False.
        lead_shape = XY_NRO.shape[:-1]
        features = XY_NRO.reshape(np.cumprod(lead_shape)[-1], 4)
    else:
        raise ValueError(f"kmeans must be fitted on 2 or 4 features, got {n_features}")

    labels = kmeans.predict(features).reshape(lead_shape)
    q0, q1, q2, q3 = ((labels == k).mean(0) for k in range(4))
    return q0, q1, q2, q3, delta_x, delta_y


def extract_error_4state(click_NRO, frequency_domain = False, accumulated = False, kmeans = None):
    """Standard error of the 4 state populations from readout click counts.

    Uses the same classification as ``extract_populations_4state`` (sign
    quadrants of the differential quantities, or *kmeans* if given). Returns
    ``std / sqrt(N)`` of each state indicator over axis 0, where N is the
    length of axis 0.

    Parameters
    ----------
    click_NRO : ndarray
        Readout counts with the 4 readout channels on the last axis; the
        second-to-last axis is summed unless *accumulated* is True.
    frequency_domain : bool
        False: ``dx = c0 - c2``, ``dy = c1 - c3`` (time-domain, phase).
        True: ``dx = c0 + c1 - c2 - c3``, ``dy = c0 + c2 - c1 - c3``.
    accumulated : bool
        True if *click_NRO* is already summed over the readout repetitions.
    kmeans : fitted classifier or None
        ``n_features_in_ == 2`` classifies (dx, dy); ``== 4`` classifies the
        4 summed readout counts.

    Returns
    -------
    e0, e1, e2, e3 : standard errors of the four populations.

    Raises
    ------
    ValueError
        If *kmeans* was fitted on a number of features other than 2 or 4.
    """
    XY_NRO = click_NRO if accumulated else click_NRO.sum(-2)
    c0, c1, c2, c3 = (np.take(XY_NRO, k, -1) for k in range(4))

    if frequency_domain:
        delta_x = c0 + c1 - c2 - c3
        delta_y = c0 + c2 - c1 - c3
    else:
        delta_x = c0 - c2
        delta_y = c1 - c3

    sqrt_n = np.sqrt(delta_x.shape[0])

    if kmeans is None:
        q0 = np.logical_and(delta_x > 0, delta_y > 0).std(0) / sqrt_n
        q1 = np.logical_and(delta_x > 0, delta_y < 0).std(0) / sqrt_n
        q2 = np.logical_and(delta_x < 0, delta_y > 0).std(0) / sqrt_n
        q3 = np.logical_and(delta_x < 0, delta_y < 0).std(0) / sqrt_n
        return q0, q1, q2, q3

    n_features = kmeans.n_features_in_
    if n_features == 2:
        lead_shape = delta_x.shape
        features = np.transpose([delta_x.flatten(), delta_y.flatten()])
    elif n_features == 4:
        # Classify the summed counts, consistent with the 2-feature branch.
        lead_shape = XY_NRO.shape[:-1]
        features = XY_NRO.reshape(np.cumprod(lead_shape)[-1], 4)
    else:
        raise ValueError(f"kmeans must be fitted on 2 or 4 features, got {n_features}")

    labels = kmeans.predict(features).reshape(lead_shape)
    q0, q1, q2, q3 = ((labels == k).std(0) / sqrt_n for k in range(4))
    return q0, q1, q2, q3


##########################################
######### Fitting functions ##############
##########################################


def fit_function(guess,func,xdata,ydata,lb = -np.inf,ub = np.inf, extend = False):
    """Fit *func* to (xdata, ydata) with ``curve_fit`` and sample the fitted curve.

    Parameters
    ----------
    guess : sequence
        Initial parameter values (``p0``).
    func : callable ``func(x, *params)``
    xdata, ydata : array-like
    lb, ub : scalar or sequence
        Lower/upper parameter bounds passed to ``curve_fit``.
    extend : False or number
        False samples the curve over the data range with 20x as many points
        as data. A number k widens the range by k times the data span on each
        side, with ``int(len(xdata) * 10 * (2k + 1))`` points.

    Returns
    -------
    est, std, fine, data_fit
        Best-fit parameters, their 1-sigma errors (sqrt of the covariance
        diagonal), the dense x grid and ``func`` evaluated on it.
    """
    est, cov = curve_fit(func, xdata, ydata, p0=guess, bounds=(lb, ub))
    std = np.sqrt(np.diag(cov))

    x_lo, x_hi = min(xdata), max(xdata)
    if not extend:
        fine = np.linspace(x_lo, x_hi, len(xdata) * 20)
    else:
        delta = x_hi - x_lo
        # linspace needs an integer count; a fractional `extend` used to raise TypeError.
        n_points = int(len(xdata) * 10 * (2 * extend + 1))
        fine = np.linspace(x_lo - extend * delta, x_hi + extend * delta, n_points)
    data_fit = func(fine, *est)
    return (est, std, fine, data_fit)

def bootstrap(fit_function, x, y, guess, n_sampling=100, plot=False):
    """Estimate fit-parameter uncertainties by bootstrap resampling.

    Draws *n_sampling* resamples of (x, y) with replacement, refits each with
    ``curve_fit`` starting from *guess*, and returns the standard deviation
    of every parameter across the successful fits. Resamples whose fit fails
    (RuntimeError / ValueError from curve_fit) are skipped.

    If *plot* is True, a new figure shows every bootstrap fit (thin red), the
    data (black dots) and the fit to the full data set (thick black). Nothing
    is drawn otherwise.

    Returns
    -------
    ndarray
        Per-parameter standard deviation over the bootstrap fits.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if plot:
        plt.figure()
    fit, _ = curve_fit(fit_function, x, y, guess)
    x_fit = np.linspace(x.min(), x.max(), 100)

    popt_list = []
    for _ in range(n_sampling):
        mask = np.random.randint(0, len(x), len(x))
        try:
            popt, _ = curve_fit(fit_function, x[mask], y[mask], guess)
        except (RuntimeError, ValueError):
            # Degenerate resamples (e.g. too few distinct x) may not converge.
            continue
        popt_list.append(popt)
        if plot:
            plt.plot(x_fit, fit_function(x_fit, *popt), 'r', alpha=0.1)

    if plot:
        plt.plot(x, y, 'ok')
        plt.plot(x_fit, fit_function(x_fit, *fit), 'k', alpha=0.7, linewidth=5)
        plt.xlim([min(x), max(x)])
        plt.ylim([min(y), max(y)])
    return np.std(np.array(popt_list), axis=0)

def lorentz(delta,delta0,kappa,a,b):
    """Lorentzian of full width *kappa* and height *a*, centred at *delta0*, on offset *b*."""
    half_width = kappa / 2
    detuning = delta - delta0
    return a / (1 + detuning**2 / half_width**2) + b

def lorentz_back(delta,delta0,kappa,a,b,c):
    """Lorentzian (full width *kappa*, height *a*, centre *delta0*) on a linear background ``b + c*delta``."""
    half_width = kappa / 2
    detuning = delta - delta0
    return a / (1 + detuning**2 / half_width**2) + b + c * delta

def double_lorentz(delta,delta_a,delta_b,kappa,a,b,c):
    """Two Lorentzians sharing full width *kappa* (heights *a*, *b* at *delta_a*, *delta_b*) plus offset *c*."""
    half_width = kappa / 2
    peak_a = a / (1 + (delta - delta_a)**2 / half_width**2)
    peak_b = b / (1 + (delta - delta_b)**2 / half_width**2)
    return peak_a + peak_b + c

def Complex_osc_decay(t,T,f,alpha,beta,phi,a):
    """Complex oscillation with a Gaussian envelope on a complex linear background.

    Returns ``alpha/2 * exp(-(t/T)**2 - 1j*2*pi*(f*t + phi))
    + beta*(1+1j) + a*(1+1j)*t``. Note that the decay is Gaussian in t, not
    exponential. *phi* is a phase in turns (multiplied by 2*pi).
    """
    phase = 2*np.pi*f*t + 2*np.pi*phi
    envelope = np.exp(-(t/T)**2 - 1j*phase)
    return alpha/2*envelope + beta + 1j*beta + (a + 1j*a)*t

def exp_decay(t,T,alpha,beta):
    """Exponential decay ``alpha * exp(-t/T) + beta``."""
    decay = np.exp(-t / T)
    return alpha * decay + beta

def rabi_fit(t,f,a,b):
    """Rabi oscillation model ``a * (b - cos(2*pi*f*t))``."""
    phase = 2 * np.pi * f * t
    return a * (b - np.cos(phase))


def gaussian(x, a, mu, sigma, b):
    """Gaussian of height *a*, centre *mu* and standard deviation *sigma* on offset *b*."""
    z = (x - mu) / sigma
    return a * np.exp(-np.power(z, 2) / 2) + b

def p_decay(x, p, A, B):
    """Power-law decay ``A * p**x + B`` (e.g. a fidelity decaying by *p* per step)."""
    return A*p**x+B



##########################################
######### Plotting functions #############
##########################################



def plot_spectroscopy_fieldsweep(path,timestamp,xlabel = "B [mT]",integration_max = -1, bins = 20):
    """Plot integrated click counts against the swept field value.

    Loads the first ``.hdf5`` file returned by ``getfiles`` (oldest by
    modification time) in ``path + timestamp``. Both strings are
    concatenated, so they must end with a separator. Each top-level group
    key is parsed as a float x value and must contain 'click_hist' and
    'time_axis' arrays.

    The traces are averaged into *bins* equal bins along time, so the trace
    length must be divisible by *bins*. Two panels are drawn:

    * top: sum over the kept bins times the mean bin spacing;
    * bottom: (sum of the first 10 kept bins - last kept bin) times the
      mean bin spacing, a crude background subtraction.

    integration_max : int
        Slice end applied to the bin axis before summing.
        NOTE(review): the default -1 drops the last bin rather than keeping
        all bins. Confirm this is intended.
    bins : int
        Number of time bins (default 20).
    """
    file = getfiles(path+timestamp,'.hdf5')[0]
    data = load_h5_to_dic(path+timestamp+file)[0]

    B0_list   = list(data.keys())
    click     = np.array([data[B0]['click_hist'] for B0 in B0_list])
    time_axis = np.array([data[B0]['time_axis'] for B0 in B0_list])
    B0list_list = np.array([float(string) for string in B0_list])

    samples_per_bin = int(time_axis.shape[1]/bins)
    time_hist = time_axis.reshape(time_axis.shape[0], bins, samples_per_bin).mean(-1)*1e-6
    click_hist = click.reshape(click.shape[0], bins, samples_per_bin)[:,:integration_max].mean(-1)

    # Mean spacing between consecutive bin centres (presumably turns count
    # rates into counts when multiplied below).
    bin_centres = time_hist.mean(0)
    time_step = (bin_centres[1:]-bin_centres[:-1]).mean(0)

    number_of_counts = click_hist.sum(-1)*time_step
    number_of_counts_subs = (click_hist[:,:10].sum(-1)-click_hist[:,-1:].sum(-1))*time_step

    fig, (ax0, ax1) = plt.subplots(2, 1, tight_layout=True)
    fig.supxlabel(xlabel)
    ax0.plot(B0list_list, number_of_counts)
    ax0.set_ylabel("Number of counts")
    ax0.grid()
    ax1.plot(B0list_list, number_of_counts_subs)
    ax1.set_ylabel("Number of counts subtracted")
    ax1.grid()

def plot_2d_sweep(data,x=None,y=None,xlabel = '',ylabel = '',clabel = '',title = '',xtick = 'auto',ytick = 'auto',
                  centre = None,vmin = None,vmax = None,cmap = sns.diverging_palette(240, 10, n=361),
                  horizontal_ticks = False, fontsize = None,annot=False):
    """
    Generic plotting function for 2D datasets: draw *data* as a seaborn heatmap.

    Rows are flipped so that ``data[0]`` is at the bottom. *x* labels the
    columns and *y* the rows. When either is None or empty, integer indices
    are used. A colour bar is drawn only when *clabel* is non-empty.
    *horizontal_ticks* sets both tick-label rotations to 0.

    Returns
    -------
    The matplotlib Axes returned by ``sns.heatmap``.
    """
    if x is None or len(x) == 0:
        x = np.arange(data.shape[1])
    if y is None or len(y) == 0:
        y = np.arange(data.shape[0])
    fieldsweep_df = pd.DataFrame(data=np.flip(data,axis = 0),index=np.flip(y,axis = 0),columns=x)
    ax = sns.heatmap(fieldsweep_df, xticklabels = xtick, yticklabels = ytick,cmap = cmap,center = centre,
                     vmin = vmin, vmax = vmax,annot=annot,annot_kws={"fontsize":10},cbar=bool(clabel))
    if clabel:
        ax.collections[0].colorbar.set_label(clabel, fontsize = fontsize)
        ax.collections[0].colorbar.ax.tick_params(labelsize=fontsize)
    plt.tick_params(labelsize = fontsize)
    plt.xlabel(xlabel, fontsize = fontsize)
    plt.ylabel(ylabel, fontsize = fontsize)
    if horizontal_ticks:
        plt.xticks(rotation=0)
        plt.yticks(rotation=0)
    plt.title(title, fontsize = fontsize)
    return ax

   
def basic2plot(data, data_prep, x, xlabel, readout_freqs, kmeans = None):
    """Four-panel summary of a 1-D sweep read out at 1 or 4 frequencies.

    Panels: mean counts per readout, state populations, histogram of counts
    (measurement vs preparation) and the FFT of the last population.

    Parameters
    ----------
    data : ndarray
        Accumulated counts, presumably indexed [shot, sweep point, readout]
        (axis 0 is averaged, axis 1 is plotted against *x*). TODO confirm.
    data_prep : array-like
        Preparation-readout counts, used only in the histogram.
    x : ndarray
        Sweep values; assumed evenly spaced for the FFT panel.
    xlabel : str
    readout_freqs : sequence
        Only its length is used. With 4, populations come from
        ``extract_populations_4state`` (frequency-domain discrimination).
        Otherwise the single population is the fraction of shots whose LAST
        readout exceeds 100 counts.
    kmeans : optional classifier forwarded to ``extract_populations_4state``.

    Returns
    -------
    ax : array of the 4 Axes
    pops : list of population arrays
    """
    fig,ax=plt.subplots(4,1,figsize=(10,15),tight_layout=True)

    if len(readout_freqs)==4:
        p0,p1,p2,p3,delta_x,delta_y = extract_populations_4state(data, frequency_domain = True, accumulated=True, kmeans = kmeans)
        pops = [p0,p1,p2,p3]
        labels = [r"${|\uparrow\uparrow\rangle}$",r"${|\uparrow\downarrow\rangle}$",r"${|\downarrow\uparrow\rangle}$",r"${|\downarrow\downarrow\rangle}$"]
    else:
        pops = [(data[:,:,-1]>100).mean(0)]
        labels = [r"${|\downarrow\downarrow\rangle}$"]

    ### Plot 1 - Number of accumulated counts ###
    mean_counts = data.mean(0)
    err_counts = data.std(0)/np.sqrt(len(data))
    for i in range(len(readout_freqs)):
        # State labels exist only for 1 or 4 readouts; other counts used to raise IndexError.
        label = labels[i] if len(labels) == len(readout_freqs) else f"Readout {i}"
        ax[0].errorbar(x, mean_counts[:,i], err_counts[:,i], label = label, fmt = "o-",
                       color = colors[i % len(colors)])

    ax[0].set_xlabel(xlabel)
    ax[0].set_ylabel("Mean counts")
    ax[0].legend(fontsize = "small", loc = "upper right")

    ### Plot 2 - Probability of the state ###
    for l, pop in zip(labels, pops):
        ax[1].plot(x, pop, "o-", label = l)

    ax[1].set_xlabel(xlabel)
    ax[1].set_ylabel("Population")
    ax[1].set_ylim(0,1)
    ax[1].legend(fontsize = "small", loc = "upper right")

    ### Plot 3 - Histogram ###
    bins=np.arange(20,np.max([np.max(np.concatenate(data)),np.max(data_prep)]),5)
    ax[2].hist(np.concatenate(data,axis = None), bins=bins   , label = r"$measurement$", alpha=0.5)
    ax[2].hist(data_prep, bins=bins, label = r"$preparation$", alpha=0.5)
    ax[2].legend()
    ax[2].set_ylabel("instances")
    ax[2].set_xlabel('counts')

    ### Plot 4 - FFT of the last population (x step taken from the first two points) ###
    fft_x = 1e3*np.fft.rfftfreq(len(x),d=x[1]-x[0])
    fft_y = np.abs(np.fft.rfft(pops[-1] - pops[-1].mean()))
    rabi_freq = fft_x[np.argmax(fft_y)]

    ax[3].plot(fft_x, fft_y)
    ax[3].vlines(rabi_freq, min(fft_y), max(fft_y)*1.1, linestyle='dashed', color='k')
    ax[3].set_xlabel("Frequency (kHz)")
    ax[3].set_ylabel("FFT")
    ax[3].set_title(f"Max frequency = {rabi_freq:.2f} kHz")

    plt.tight_layout()

    return ax, pops


def plot_flattop_ft(pulse_duration, center_freq, ramp_time = 0):
    """
    Return an x and y array with the Fourier transform of a flattop pulse (height = 1).

    Pulse duration is given in clock cycles (one sample per 4 ns cycle) and
    frequency is in Hz. *ramp_time* is the number of samples in each
    cosine ramp built by ``chirp_cos_raise`` (0 gives a square pulse). The
    pulse is zero-padded with 2.5e6 samples on each side for fine frequency
    resolution. The magnitude spectrum is mirrored to negative frequencies
    (the zero-frequency point appears twice), shifted by *center_freq* and
    normalised to a peak of 1.

    Returns
    -------
    (frequencies, normalised |FT|)
    """
    around_pulse = 10e6
    padding = np.zeros(int(around_pulse/4))
    ramp_up = chirp_cos_raise(ramp_time,1,0)[0]
    ramp_down = chirp_cos_raise(ramp_time,-1,0)[0]+1
    vals = np.concatenate((padding, ramp_up, [1]*pulse_duration, ramp_down, padding))
    # Exactly one 4 ns sample per point, written in ms (4e-6) so that
    # 1e3 * rfftfreq (1/ms = kHz) yields Hz. The previous linspace time axis
    # stretched the spacing by N/(N-1).
    sample_spacing = 4*1e-6
    fft_x = 1e3*np.fft.rfftfreq(len(vals), d=sample_spacing)
    fft_y = np.abs(np.fft.rfft(vals))
    full_pulse = np.concatenate((np.flip(fft_y),fft_y))
    full_freq = np.concatenate((-np.flip(fft_x),fft_x))
    return full_freq+center_freq, full_pulse/max(full_pulse)


def _plot_population_panel(x, populations, title, xlabel):
    """Draw the four quadrant populations against *x* on the current subplot."""
    state_labels = (r"$|\downarrow\uparrow  \rangle$",
                    r"$|\downarrow\downarrow\rangle$",
                    r"$|\uparrow  \downarrow\rangle$",
                    r"$|\uparrow  \uparrow  \rangle$")
    plt.title(title)
    for population, label in zip(populations, state_labels):
        plt.plot(x, population, label=label)
    plt.xlabel(xlabel)
    plt.ylabel('Probability')
    plt.legend()


def plot_populations(x,click_NRO_prep,click_NRO,directory,filename,save = True, xlabel = 'Frequency (kHz)'):
    """Plot nuclear spin populations for preparation and readout.

    Assumes 4-state Ramsey readout. Both click arrays go through
    ``extract_populations_4state`` with its defaults (time-domain
    discrimination, counts summed over axis -2). Subplots:

    * (2,2,1) preparation populations against *x*;
    * (2,2,2) readout ("Spectroscopy") populations against *x*;
    * (2,2,4) scatter of the differential quantities (readout in the default
      colour, preparation in red) with dashed lines at the quadrant
      boundaries.

    If *save* is True, the figure is written to
    ``directory + filename + '_spectrum.pdf'``. A failed save is reported and
    does not raise. The figure is then shown.
    """
    prep = extract_populations_4state(click_NRO_prep)
    readout = extract_populations_4state(click_NRO)
    delta_x_prep, delta_y_prep = prep[4], prep[5]
    delta_x, delta_y = readout[4], readout[5]

    plt.figure(figsize = (14,14))

    plt.subplot(2,2,1)
    _plot_population_panel(x, prep[:4], 'Preparation', xlabel)

    plt.subplot(2,2,2)
    _plot_population_panel(x, readout[:4], 'Spectroscopy', xlabel)

    plt.subplot(2,2,4)
    plt.scatter(delta_x, delta_y, s=10, alpha = 0.1)
    plt.scatter(delta_x_prep, delta_y_prep, s=10, alpha = 0.1, color = 'red')
    plt.xlabel(r'$\Delta_\mathrm{clicks}~x$')
    plt.ylabel(r'$\Delta_\mathrm{clicks}~y$')
    plt.axvline(0, linestyle = "--",color = 'red', alpha = 0.7)
    plt.axhline(0, linestyle = "--",color = 'red', alpha = 0.7)
    plt.grid()

    plt.tight_layout()
    if save:
        try:
            plt.savefig(directory+filename+'_spectrum.pdf')
        except Exception as err:  # best-effort save: report and carry on to show()
            print(f"saving failed: {err}")
    plt.show()


def kmeans_plot(delta, h=1, ax=None, title=''):
    """Cluster (delta_x, delta_y) points into 4 states with k-means and plot them.

    Parameters
    ----------
    delta : pair (delta_x, delta_y) of arrays
        Indexed below as ``np.array(delta)[:, :, i]`` for i in 0..3, so each
        is presumably shaped (N shots, 4 prepared states). TODO confirm.
    h : float
        Grid step (in counts) of the mesh used to shade the decision regions.
    ax : matplotlib Axes or None
        Axes to draw on; a new 5x5 figure is created when None.
    title : str
        Axes title.

    Clustering is seeded at (+-100, +-100) with a single initialisation, and
    fitted on the points truncated to int.

    Returns
    -------
    probs : list of 4 lists
        probs[i][j] is the fraction of points in column i of *delta* that are
        assigned to cluster j.
    kmeans : the fitted sklearn ``KMeans`` model (2 features).
    """
    init = np.array([
        [ 100, 100],
        [ 100,-100],
        [-100, 100],
        [-100,-100]
    ])

    # One row per (delta_x, delta_y) pair.
    X = np.array([delta[0].flatten(), delta[1].flatten()]).astype(int).T
    kmeans = KMeans(n_clusters=4, random_state=0, init=init, n_init=1)
    kmeans.fit(X)

    # Label every mesh point so the cluster regions can be shaded.
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    Z = kmeans.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    if ax is None:
        plt.figure(figsize=(5,5))
        ax = plt.gca()

    cmap = mpl.colors.ListedColormap(colors[:4])

    ax.plot(delta[0],delta[1],".", alpha = 0.5)
    ax.imshow(
        Z, interpolation='nearest', cmap=cmap, alpha=0.1,
        extent=(xx.min(), xx.max(), yy.min(), yy.max()),
        aspect='auto', origin='lower',
    )
    ax.set_xlabel(r'Counts $\Delta_x$')
    ax.set_ylabel(r'Counts $\Delta_y$')

    points = np.array(delta)
    probs = []
    for i in range(4):
        pred = kmeans.predict(points[:,:,i].T)
        probs.append([(pred==j).sum()/len(pred) for j in range(4)])
    ax.set_title(title, color='k', ha ='center', fontsize='medium')

    return probs, kmeans




##########################################
######### Pulse generating functions #####
##########################################



def gauss(amplitude, mu, sigma, length):
    """Sample a Gaussian on *length* points spread evenly over [-length/2, length/2].

    Returns a plain list of Python floats,
    ``amplitude * exp(-(t - mu)**2 / (2 * sigma**2))``.
    """
    t = np.linspace(-length / 2, length / 2, length)
    envelope = np.exp(-((t - mu) ** 2) / (2 * sigma ** 2))
    return list(map(float, amplitude * envelope))


def ErfSquarePulseHeight(numberOfPoints, start, stop, sigma, amplitude=1):
    """
    'Square' pulse with gaussian (erf) rise and fall, specified by its height.

    Sampled at integer points 0..numberOfPoints-1. The rising edge is
    centred at ``start + 2*sigma`` and the falling edge at
    ``stop - 2*sigma``. The plateau value is *amplitude*.
    """
    t = np.arange(numberOfPoints)
    rise = (spec.erf((t - start - 2 * sigma) / sigma) + 1) / 2
    fall = (spec.erf((t - stop + 2. * sigma) / sigma) + 1) / 2
    return amplitude * (rise - fall)


def ErfRising(numberOfPoints, sigma, amplitude=1):
    """Erf step from 0 to *amplitude*, centred on the middle of numberOfPoints samples, width *sigma*."""
    centre = numberOfPoints / 2
    t = np.arange(numberOfPoints)
    return amplitude * (spec.erf((t - centre) / sigma) + 1) / 2


def chirp_cos_raise(length, amp, df):
    """Return the (I, Q) quadratures of a raised-cosine ramp with a frequency chirp.

    The envelope rises as ``(1 - cos(pi * t / length)) / 2`` from 0 to 1 over
    ``int(abs(length))`` samples of t from 0 to *length*. The phase is the
    running trapezoidal integral of ``df * (1 - envelope**2)`` over t. A
    negative *length* samples t downwards, which also flips the sign of the
    accumulated phase. A length with ``int(abs(length)) == 0`` returns two
    empty arrays.

    Returns
    -------
    ``amp * env * cos(phase)``, ``amp * env * sin(phase)``
    """
    try:
        from scipy.integrate import cumulative_trapezoid
    except ImportError:  # SciPy < 1.6 only provides the old name
        from scipy.integrate import cumtrapz as cumulative_trapezoid

    n_samples = int(abs(length))
    if n_samples == 0:
        # Nothing to integrate; recent SciPy raises on an empty integration axis.
        return np.zeros(0), np.zeros(0)
    t = np.linspace(0, length, n_samples)
    cosfunc = (1-np.cos(np.pi*t/length))/2
    phase = cumulative_trapezoid(df*(1-cosfunc**2), x=t, initial=0)
    return amp*cosfunc*np.cos(phase), amp*cosfunc*np.sin(phase)


